{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('dataset-bpe.json') as fopen:\n",
    "    data = json.load(fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X = data['train_X']\n",
    "train_Y = data['train_Y']\n",
    "test_X = data['test_X']\n",
    "test_Y = data['test_Y']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "EOS = 2\n",
    "GO = 1\n",
    "vocab_size = 32000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_Y = [i + [2] for i in train_Y]\n",
    "test_Y = [i + [2] for i in test_Y]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.utils import beam_search\n",
    "\n",
    "def pad_second_dim(x, desired_size):\n",
    "    padding = tf.tile([[[0.0]]], tf.stack([tf.shape(x)[0], desired_size - tf.shape(x)[1], tf.shape(x)[2]], 0))\n",
    "    return tf.concat([x, padding], 1)\n",
    "\n",
    "class Translator:\n",
    "    def __init__(self, size_layer, num_layers, embedded_size, learning_rate):\n",
    "        \n",
    "        def cells(reuse=False):\n",
    "            return tf.nn.rnn_cell.LSTMCell(size_layer,initializer=tf.orthogonal_initializer(),reuse=reuse)\n",
    "        \n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.Y = tf.placeholder(tf.int32, [None, None])\n",
    "        \n",
    "        self.X_seq_len = tf.count_nonzero(self.X, 1, dtype = tf.int32)\n",
    "        self.Y_seq_len = tf.count_nonzero(self.Y, 1, dtype = tf.int32)\n",
    "        batch_size = tf.shape(self.X)[0]\n",
    "        \n",
    "        embeddings = tf.Variable(tf.random_uniform([vocab_size, embedded_size], -1, 1))\n",
    "        \n",
    "        _, encoder_state = tf.nn.dynamic_rnn(\n",
    "            cell = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)]), \n",
    "            inputs = tf.nn.embedding_lookup(embeddings, self.X),\n",
    "            sequence_length = self.X_seq_len,\n",
    "            dtype = tf.float32)\n",
    "        main = tf.strided_slice(self.Y, [0, 0], [batch_size, -1], [1, 1])\n",
    "        decoder_input = tf.concat([tf.fill([batch_size, 1], GO), main], 1)\n",
    "        dense = tf.layers.Dense(vocab_size)\n",
    "        decoder_cells = tf.nn.rnn_cell.MultiRNNCell([cells() for _ in range(num_layers)])\n",
    "        \n",
    "        training_helper = tf.contrib.seq2seq.TrainingHelper(\n",
    "                inputs = tf.nn.embedding_lookup(embeddings, decoder_input),\n",
    "                sequence_length = self.Y_seq_len,\n",
    "                time_major = False)\n",
    "        training_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = training_helper,\n",
    "                initial_state = encoder_state,\n",
    "                output_layer = dense)\n",
    "        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = training_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = tf.reduce_max(self.Y_seq_len))\n",
    "        self.training_logits = training_decoder_output.rnn_output\n",
    "        \n",
    "        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(\n",
    "                embedding = embeddings,\n",
    "                start_tokens = tf.tile(tf.constant([GO], dtype=tf.int32), [batch_size]),\n",
    "                end_token = EOS)\n",
    "        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(\n",
    "                cell = decoder_cells,\n",
    "                helper = predicting_helper,\n",
    "                initial_state = encoder_state,\n",
    "                output_layer = dense)\n",
    "        predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(\n",
    "                decoder = predicting_decoder,\n",
    "                impute_finished = True,\n",
    "                maximum_iterations = 2 * tf.reduce_max(self.X_seq_len))\n",
    "        self.fast_result = predicting_decoder_output.sample_id\n",
    "        \n",
    "        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)\n",
    "        self.cost = tf.contrib.seq2seq.sequence_loss(logits = self.training_logits,\n",
    "                                                     targets = self.Y,\n",
    "                                                     weights = masks)\n",
    "        self.optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate).minimize(self.cost)\n",
    "        y_t = tf.argmax(self.training_logits,axis=2)\n",
    "        y_t = tf.cast(y_t, tf.int32)\n",
    "        self.prediction = tf.boolean_mask(y_t, masks)\n",
    "        mask_label = tf.boolean_mask(self.Y, masks)\n",
    "        correct_pred = tf.equal(self.prediction, mask_label)\n",
    "        correct_index = tf.cast(correct_pred, tf.float32)\n",
    "        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "size_layer = 512\n",
    "num_layers = 2\n",
    "embedded_size = 256\n",
    "learning_rate = 1e-3\n",
    "batch_size = 128\n",
    "epoch = 20"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n",
      "WARNING:tensorflow:From <ipython-input-7-52c6c60cd4f9>:11: LSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.\n",
      "WARNING:tensorflow:From <ipython-input-7-52c6c60cd4f9>:23: MultiRNNCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0.\n",
      "WARNING:tensorflow:From <ipython-input-7-52c6c60cd4f9>:26: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:958: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.add_weight` method instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn_cell_impl.py:962: calling Zeros.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Call initializer instance with the dtype argument instead of passing it to the constructor\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/ops/rnn.py:244: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Translator(size_layer, num_layers, embedded_size, learning_rate)\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "pad_sequences = tf.keras.preprocessing.sequence.pad_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([[30862, 31330,  7563,  6904,  6904, 14420, 14420,  5320,  3882,\n",
       "          3882, 25569,  4014,  4014,  5510,  2817,  2817,  2817,  8488,\n",
       "         11273,  5878,  5637,  1377,  8304,  8304,  5204,  5204, 15208,\n",
       "         12328, 12328, 31810, 31810, 31810, 20439, 19659, 19659,  4132,\n",
       "          4132, 25043,  1482,  1482, 20925, 20925, 20925, 15262, 15262,\n",
       "         15262, 15262, 15262,   985,   985,   985,   985,   985, 17364,\n",
       "         17364, 21699, 21699, 21699, 17364, 15097, 15097, 22568, 31795,\n",
       "         31795, 31510, 31510, 10779,  6458,  6458,   145,   145,  5615],\n",
       "        [13133, 23439, 23439, 18141,  4568,  4568, 13649, 31341,   808,\n",
       "           808, 11462, 11462,  1480,  1480,  1480, 27723, 27723, 27723,\n",
       "         13440, 16353, 16353, 16353, 20605,  4837,  6188,  6188,  6188,\n",
       "          3888, 24510, 24510,   714,   714,   809,   809,  4190, 19513,\n",
       "         19513, 19513,  7917,  7917,  7917,  8973,  3280,  3280, 24510,\n",
       "         24510, 24510, 25452,  7617,  7617,  7617, 26629,   993, 29224,\n",
       "         29224, 17531, 17531, 17531, 17531, 16739, 16739, 16739, 25561,\n",
       "         16774, 24148, 24148, 24148,  7669,  7669,  7669,  7669,  7669],\n",
       "        [10387, 15069, 14202, 26205, 26205, 27461, 17987, 15172, 17987,\n",
       "         10596,  4460,  1914, 12050, 12050, 30584,  8982,  9221,  9221,\n",
       "          9221,  9221, 26518, 26518, 26518, 19941, 22487, 22487, 22487,\n",
       "          2744,  2744,  2744,  2744,  2744, 11374, 19791, 19791, 19791,\n",
       "          8720,  8720, 26791,  1952, 25679, 14608, 14608, 12115, 17667,\n",
       "         10218, 10218, 18171, 18171, 20997, 29896, 30158, 30158, 30158,\n",
       "         30158, 10967, 10967,   349,   349,   349,   349, 20231, 20231,\n",
       "         20231, 22875,  4251, 26256, 26256,  4251, 26256, 25270, 30956],\n",
       "        [17129, 25926, 10990, 21443, 12768, 12768, 12768, 17867,  2925,\n",
       "          2925,  2925, 30130, 30130, 16013,  5830,  5830,  4894,  4894,\n",
       "          4894, 20143, 20143, 20143, 26802,  8537,  8537,  8537, 20865,\n",
       "         20865,  2439, 28033, 28033, 25857, 25857, 25857, 25857, 30418,\n",
       "         11507, 11507, 24171,  1245, 14881, 14881, 14881,  1646,  1646,\n",
       "         15021, 18507, 18507, 22884, 22884, 22884, 21795, 20631, 21795,\n",
       "         14721,  2768,  2768,  2768,  2768, 10262, 10262, 10262, 10262,\n",
       "          7427,  7427,  7427, 29551,  4018,  4018, 25970, 25970, 15114],\n",
       "        [ 6164, 18130, 12590, 12590, 22029, 22029, 14948,  7164, 10457,\n",
       "          7164,  8252, 14881, 14881, 15305, 15305,  1646,  1646,  9217,\n",
       "          9217, 30041, 30041, 30041, 23367, 23367, 23367, 12176, 12176,\n",
       "         12176, 12176, 12176, 12176, 27389, 27389, 27389, 13683, 13683,\n",
       "         25753, 25753, 25753, 25753,  3222,  3222,  3222, 31032, 31032,\n",
       "         31032,  9520,  5513,  5513,  5513, 30051, 30051, 30051, 29008,\n",
       "         29008, 30051, 19999, 12553, 15369, 12553, 27068, 27068, 15771,\n",
       "         15771, 15771, 15771, 15771, 24180, 13710, 13710, 13710, 13710],\n",
       "        [ 6720, 14590, 14590, 23462, 23462, 30270, 30270, 30270, 30270,\n",
       "          9655,  9655,  9655,  3925,  3925, 23030, 10941, 10941, 10115,\n",
       "         10115, 22491,    94,  9734,  8718,  8718, 13733, 27096, 27096,\n",
       "         27096, 12767, 12767,  2167, 23666, 23666, 14726, 14726,  4900,\n",
       "           958,   958,   958, 11290,   958,  5010,  5010, 19747,  5267,\n",
       "          5267,  5267, 31828,  1623,  1623,  1623, 29611, 29611, 23917,\n",
       "         23917, 23360, 23360, 18726, 18726, 19774, 19774,  5943,  5943,\n",
       "           221,   221,   221,  8403,  8403,  8403,  3926, 10207, 17169],\n",
       "        [18851, 11340, 25700, 25700, 29668, 30601,  4782, 24464, 10815,\n",
       "         10815, 10815, 10815,   876, 16780, 16780, 16780, 16780,  8640,\n",
       "          8640, 25856, 27461, 27461,  2104,  2104,  2104, 26027, 18308,\n",
       "         18308,   676,   676,   676,  1298,  1298,  1298,  1298, 28409,\n",
       "         28409, 28409, 24181, 24181, 14966, 14966, 14966, 14966, 27829,\n",
       "         27829,  1617,  3132,  3132, 25578, 25578, 11490, 28680, 28680,\n",
       "         30474,  9796,  9796,  1281, 16502, 19883, 19883, 19571, 19571,\n",
       "         19571, 19571, 22521, 22521, 22521, 23214, 23214, 23853, 20656],\n",
       "        [ 3058, 14447, 26205, 25019,  4542,  4542, 31133, 31133, 20898,\n",
       "         20898, 20898,  7916, 10869, 10869, 10869, 31337, 10869,  7399,\n",
       "          7399,  1692, 20518,  6416,  6416, 23824, 23824,  9857,  9857,\n",
       "          9857,  9857, 21825,  6794,  6794, 10779, 10779, 10121, 10121,\n",
       "         10121, 10121, 30109, 30109, 30109, 11582, 29442, 29442,  3128,\n",
       "          3128,  7128,  7128, 25740, 25740, 25740,  5143,  7755,  7755,\n",
       "          7755,   630, 31952,  7755,  9976,  3878,  3878,  3878,  3878,\n",
       "          3878,  3878, 23798, 23798, 15016, 23798,  1306,  1306,  4477],\n",
       "        [13597, 13597, 13597, 13597, 20651, 20651, 20651, 20651, 24766,\n",
       "         31083,  7296,  7296,  7296,  7296,  7296,  7296, 17042, 17042,\n",
       "          5782, 20873, 20873, 20873, 20873, 20873, 11467, 11467, 27926,\n",
       "         27926, 27926, 27926, 12095, 12095,  8448, 30264, 30264, 10778,\n",
       "         13596, 10778, 24407, 24407, 23176, 18902, 18902, 18902, 30810,\n",
       "         30810, 24019, 24019,  3817, 24724, 24724, 24437, 21076, 21076,\n",
       "         28865, 28865,  7518, 30761, 30761,  4190, 14074,   651,   651,\n",
       "          7242,   809,  9368, 17583,  7320, 11770, 11770, 13779, 12818],\n",
       "        [31185, 31185,  3062, 17654, 17654, 22079, 24136,  1765, 21788,\n",
       "         21788, 13298, 30743, 30743, 30743, 12946, 30190, 30190,  4244,\n",
       "          4244,  4244, 22705, 22705, 31853, 31853, 31853, 22705,  5248,\n",
       "          5248, 23739, 19055, 19055, 19055,  5248, 20593, 20593, 20593,\n",
       "         14267, 14267, 29946, 23329, 23329,  4287,  2133,  2133,  9073,\n",
       "         15932, 29761, 29761, 29761, 29761, 16181, 16181, 25048, 25048,\n",
       "         17627, 17627,  9932,  7953,  1529,  1529,  1529, 22211, 20732,\n",
       "         27421, 14005, 14005, 16275, 16275, 12778, 12778, 19235,  7967]],\n",
       "       dtype=int32), 10.373379, 0.005076142]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_x = pad_sequences(train_X[:10], padding='post')\n",
    "batch_y = pad_sequences(train_Y[:10], padding='post')\n",
    "\n",
    "sess.run([model.fast_result, model.cost, model.accuracy], \n",
    "         feed_dict = {model.X: batch_x, model.Y: batch_y})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [15:45<00:00,  1.65it/s, accuracy=0.208, cost=5.16]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.88it/s, accuracy=0.21, cost=4.57] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, training avg loss 6.248373, training avg acc 0.146723\n",
      "epoch 1, testing avg loss 5.000242, testing avg acc 0.223035\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:13<00:00,  1.61it/s, accuracy=0.275, cost=4.4] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:11<00:00,  3.57it/s, accuracy=0.301, cost=3.94]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, training avg loss 4.603188, training avg acc 0.254305\n",
      "epoch 2, testing avg loss 4.322093, testing avg acc 0.279571\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:15<00:00,  1.60it/s, accuracy=0.322, cost=3.84]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:14<00:00,  2.76it/s, accuracy=0.344, cost=3.58]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, training avg loss 4.025585, training avg acc 0.305634\n",
      "epoch 3, testing avg loss 3.974566, testing avg acc 0.314820\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:13<00:00,  1.61it/s, accuracy=0.374, cost=3.37]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:14<00:00,  2.82it/s, accuracy=0.382, cost=3.36]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, training avg loss 3.634333, training avg acc 0.345598\n",
      "epoch 4, testing avg loss 3.778099, testing avg acc 0.337871\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:13<00:00,  1.60it/s, accuracy=0.414, cost=3.01]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.87it/s, accuracy=0.387, cost=3.29]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, training avg loss 3.339964, training avg acc 0.378003\n",
      "epoch 5, testing avg loss 3.673842, testing avg acc 0.350452\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:30<00:00,  1.58it/s, accuracy=0.457, cost=2.69]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.96it/s, accuracy=0.382, cost=3.27]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 6, training avg loss 3.105059, training avg acc 0.405415\n",
      "epoch 6, testing avg loss 3.626813, testing avg acc 0.356313\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:12<00:00,  1.61it/s, accuracy=0.497, cost=2.43]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.88it/s, accuracy=0.398, cost=3.23]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 7, training avg loss 2.910636, training avg acc 0.429180\n",
      "epoch 7, testing avg loss 3.618793, testing avg acc 0.359616\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:13<00:00,  1.61it/s, accuracy=0.524, cost=2.21]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.87it/s, accuracy=0.387, cost=3.24]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 8, training avg loss 2.744472, training avg acc 0.451020\n",
      "epoch 8, testing avg loss 3.630387, testing avg acc 0.361000\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:12<00:00,  1.61it/s, accuracy=0.56, cost=2.02] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.90it/s, accuracy=0.387, cost=3.3] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 9, training avg loss 2.597390, training avg acc 0.471465\n",
      "epoch 9, testing avg loss 3.659572, testing avg acc 0.362244\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:20<00:00,  1.59it/s, accuracy=0.598, cost=1.86]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.93it/s, accuracy=0.366, cost=3.3] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10, training avg loss 2.467286, training avg acc 0.490605\n",
      "epoch 10, testing avg loss 3.694587, testing avg acc 0.359903\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:06<00:00,  1.62it/s, accuracy=0.624, cost=1.72]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.89it/s, accuracy=0.36, cost=3.35] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 11, training avg loss 2.349300, training avg acc 0.508146\n",
      "epoch 11, testing avg loss 3.738848, testing avg acc 0.355326\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:12<00:00,  1.61it/s, accuracy=0.636, cost=1.58]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.87it/s, accuracy=0.349, cost=3.38]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 12, training avg loss 2.240449, training avg acc 0.525287\n",
      "epoch 12, testing avg loss 3.793463, testing avg acc 0.353127\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:13<00:00,  1.60it/s, accuracy=0.667, cost=1.45]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.91it/s, accuracy=0.376, cost=3.41]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 13, training avg loss 2.140891, training avg acc 0.541594\n",
      "epoch 13, testing avg loss 3.859928, testing avg acc 0.352772\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:12<00:00,  1.61it/s, accuracy=0.689, cost=1.37]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.88it/s, accuracy=0.355, cost=3.49]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 14, training avg loss 2.050010, training avg acc 0.556602\n",
      "epoch 14, testing avg loss 3.911178, testing avg acc 0.350626\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:12<00:00,  1.61it/s, accuracy=0.717, cost=1.27]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.93it/s, accuracy=0.344, cost=3.48]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 15, training avg loss 1.965698, training avg acc 0.571013\n",
      "epoch 15, testing avg loss 3.974208, testing avg acc 0.348345\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:11<00:00,  1.61it/s, accuracy=0.729, cost=1.2] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:14<00:00,  2.80it/s, accuracy=0.355, cost=3.6] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 16, training avg loss 1.887398, training avg acc 0.584290\n",
      "epoch 16, testing avg loss 4.065328, testing avg acc 0.345706\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:10<00:00,  1.61it/s, accuracy=0.748, cost=1.12]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:14<00:00,  2.81it/s, accuracy=0.339, cost=3.62]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 17, training avg loss 1.817722, training avg acc 0.596423\n",
      "epoch 17, testing avg loss 4.142617, testing avg acc 0.341623\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:21<00:00,  1.59it/s, accuracy=0.78, cost=1.04] \n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.92it/s, accuracy=0.328, cost=3.7] \n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 18, training avg loss 1.751493, training avg acc 0.608146\n",
      "epoch 18, testing avg loss 4.192155, testing avg acc 0.336698\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [16:22<00:00,  1.59it/s, accuracy=0.786, cost=0.977]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:13<00:00,  2.92it/s, accuracy=0.349, cost=3.68]\n",
      "minibatch loop:   0%|          | 0/1563 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 19, training avg loss 1.686881, training avg acc 0.619580\n",
      "epoch 19, testing avg loss 4.243670, testing avg acc 0.335427\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "minibatch loop: 100%|██████████| 1563/1563 [10:37<00:00,  2.45it/s, accuracy=0.794, cost=0.912]\n",
      "minibatch loop: 100%|██████████| 40/40 [00:06<00:00,  6.60it/s, accuracy=0.36, cost=3.77] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 20, training avg loss 1.626613, training avg acc 0.630753\n",
      "epoch 20, testing avg loss 4.310580, testing avg acc 0.334834\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import tqdm\n",
    "\n",
    "for e in range(epoch):\n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(train_X), batch_size), desc = 'minibatch loop')\n",
    "    train_loss, train_acc, test_loss, test_acc = [], [], [], []\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(train_X))\n",
    "        batch_x = pad_sequences(train_X[i : index], padding='post')\n",
    "        batch_y = pad_sequences(train_Y[i : index], padding='post')\n",
    "        feed = {model.X: batch_x,\n",
    "                model.Y: batch_y}\n",
    "        accuracy, loss, _ = sess.run([model.accuracy,model.cost,model.optimizer],\n",
    "                                    feed_dict = feed)\n",
    "        train_loss.append(loss)\n",
    "        train_acc.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "    \n",
    "    \n",
    "    pbar = tqdm.tqdm(\n",
    "        range(0, len(test_X), batch_size), desc = 'minibatch loop')\n",
    "    for i in pbar:\n",
    "        index = min(i + batch_size, len(test_X))\n",
    "        batch_x = pad_sequences(test_X[i : index], padding='post')\n",
    "        batch_y = pad_sequences(test_Y[i : index], padding='post')\n",
    "        feed = {model.X: batch_x,\n",
    "                model.Y: batch_y,}\n",
    "        accuracy, loss = sess.run([model.accuracy,model.cost],\n",
    "                                    feed_dict = feed)\n",
    "\n",
    "        test_loss.append(loss)\n",
    "        test_acc.append(accuracy)\n",
    "        pbar.set_postfix(cost = loss, accuracy = accuracy)\n",
    "    \n",
    "    print('epoch %d, training avg loss %f, training avg acc %f'%(e+1,\n",
    "                                                                 np.mean(train_loss),np.mean(train_acc)))\n",
    "    print('epoch %d, testing avg loss %f, testing avg acc %f'%(e+1,\n",
    "                                                              np.mean(test_loss),np.mean(test_acc)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensor2tensor.utils import bleu_hook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 40/40 [00:14<00:00,  2.68it/s]\n"
     ]
    }
   ],
   "source": [
    "results = []\n",
    "for i in tqdm.tqdm(range(0, len(test_X), batch_size)):\n",
    "    index = min(i + batch_size, len(test_X))\n",
    "    batch_x = pad_sequences(test_X[i : index], padding='post')\n",
    "    feed = {model.X: batch_x}\n",
    "    p = sess.run(model.fast_result,feed_dict = feed)\n",
    "    result = []\n",
    "    for row in p:\n",
    "        result.append([i for i in row if i > 3])\n",
    "    results.extend(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "rights = []\n",
    "for r in test_Y:\n",
    "    rights.append([i for i in r if i > 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bleu_hook.compute_bleu(reference_corpus = rights,\n",
    "                       translation_corpus = results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
